// This file is part of Module Proxy.

// Module Proxy is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// Module Proxy is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with Module Proxy.  If not, see <https://www.gnu.org/licenses/>.


//         Copyright (C) 2021 - 2030  关中麦客  
//         All rights reserved
//
//         mod_uuid.rs
//         uuid鉴权
//
//         created by 关中麦客 1036038462@qq.com

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tokio::net::TcpStream;
use std::time::Duration;
use tokio::time::sleep;
use hyper::{Body, Request, Response, Error};
use hyper::header::{HeaderMap, HeaderValue};

use crate::mod_uuid;
use crate::mod_uuidforward;

/// Reply written back on the uuid registration socket when the message was stored.
const SOCKET_OK: &str = "OK";
/// Reply written back on the uuid registration socket when registration fails.
const SOCKET_ERROR: &str = "ERROR";

/// 初始化，由main.rs调用
pub async fn init()
{
    //uuid鉴权非使能（被注释）
    if !super::conf::enable_module_uuid_auth()
    {
        return;  
    }

    log::info!("mod_uuid_auth init...");

    //配置文件中有uuid鉴权的配置（未注释）
    if let Some(port) = super::conf::module_uuid_port()
    {
        // TCP 侦听------------
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
        let listener = TcpListener::bind(addr).await.unwrap();

        log::info!("mod_uuid_auth listening port: {}", port);

        tokio::spawn(async move {
            loop 
            {
                // accept
                let (socket, _) = listener.accept().await.unwrap();
                process(socket).await;
            }
        });
    }

    mod_uuid::release_thread().await;
    mod_uuidforward::init_cluster();
}

/// Starts the background sweeper for expired uuids (invoked from `init`).
///
/// Spawns a detached tokio task that wakes every 30 seconds, asks `mod_uuidmap` for
/// all expired uuids and deletes each of them from the map.
pub async fn release_thread()
{
    let sweep_interval = Duration::from_secs(30);

    tokio::spawn(async move {
        loop
        {
            sleep(sweep_interval).await;

            let expired = super::mod_uuidmap::get_exp().await;
            for stale_uuid in expired
            {
                super::mod_uuidmap::del(stale_uuid).await;
            }
        }
    });
}

/// 模块名是否匹配mask
pub fn exist(module_name: &str) -> bool
{
    //uuid鉴权非使能（被注释）
    if !super::conf::enable_module_uuid_auth()
    {
        return false;  
    }

    let req_v: Vec<&str> = module_name.split("-").collect();    //模块名结构
    let mask = super::conf::module_uuid_mask().unwrap();
    let mask_v: Vec<&str>= mask.split("-").collect();           //mask结构

    if req_v.len() != mask_v.len()
    {
        return false;       //结构段数量不一致
    }

    for i in 0..req_v.len()
    {
        if req_v[i].len() != mask_v[i].len()
        {
            return false;   //某段长度不一致
        }
    }

    true
}

/// 转发
pub async fn forward(mut request: Request<Body>, uuid: &str) -> Result<Response<Body>, Error>
{
    log::debug!("[mod_uuid] forward uuid: {}", uuid);
    // 在header中添加 mod_proxy_uuid -------------
    let header_map: &mut HeaderMap<HeaderValue> = request.headers_mut();
    loop
    {
        if None == header_map.remove("mod_proxy_uuid")  //删除已有的（避免客户端伪装header）
        {
            break;
        }
    }
    log::debug!("clear header mod_proxy_uuid");
    //添加header
    if let Some((val, _timeout)) = super::mod_uuidmap::get(uuid).await
    {
        log::debug!("[mod_uuid] mod_proxy_uuid val: {}", val);
        header_map.insert("mod_proxy_uuid", val.parse().unwrap());
        //重新置入uuid缓存（更新失效时间）
        let exptime = super::util::sec_timestamp() + super::conf::module_uuid_exptime().unwrap();
        super::mod_uuidmap::set(String::from(uuid), val, exptime).await; 
    }

    // 去除uri中的UUID-------------------
    // 例如 /UUID/aaa/bbb/ccc?name=valu
    let old_uri = request.uri().path_and_query().map(|x| x.as_str()).unwrap_or("/");
    //url中去除UUID部分, 保留后面的 /aaa/bbb/ccc?name=valu
    let remove_str = format!("/{}", uuid);
    let new_uri = old_uri.replace(&remove_str, "");

    //新uri置入源request 
    match new_uri.parse::<hyper::Uri>()
    {
        Ok(uri) => 
        {
            *request.uri_mut() = uri;
        },
        Err(err) => 
        {
            log::error!("{} - uuid uri error: {}", new_uri, err);
            return Ok(super::response::rsp_500().await);
        },
    }

    //http 转发 --------------------
    if let Some(second_mod_name) = super::module(&new_uri) //这里的second_mod_name是uuid后面的
    {
        if super::mod_uuidforward::exist(&second_mod_name) //配置中存在第二段模块名称
        {
            return mod_uuidforward::forward(request, &second_mod_name).await;
        }
    }

    return Ok(super::response::rsp_500().await);
}

/// 处理socket
async fn process(mut socket: TcpStream)
{
    let mut buf: [u8; 1024] = [0; 1024];                //缓冲
    if let Ok(c_size) = socket.read(&mut buf).await     //socket读
    {
        //[u8]转&str
        match std::str::from_utf8(&buf[..c_size]) 
        {
            Ok(msg) =>
            {
                if let Some((uuid, val)) = parse_socket(msg) //解析消息
                {
                    let exptime = super::util::sec_timestamp() + super::conf::module_uuid_exptime().unwrap();
                    super::mod_uuidmap::set(uuid, val, exptime).await; //置入uuid缓存
                    let _ = socket.write_all(SOCKET_OK.as_bytes()).await; //socket返回成功
                    return;
                }
            }
            Err(err) =>
            {
                log::warn!("[mod_uuid] socket read msg from [u8] to &str error: {}", err);
                return;
            }
        }
    }

    let _ = socket.write_all(SOCKET_ERROR.as_bytes()).await; //socket返回失败
}

/// Splits a registration message `<uuid>:<val>` into `(uuid, val)`.
///
/// The uuid part is exactly as many bytes as the configured `module_uuid_mask`, and the
/// byte right after it must be `':'`. Everything after that colon is the value; it may be
/// empty and may itself contain `':'`. Returns `None` when the message is not longer than
/// the mask or the separator is not at that position.
///
/// # Panics
/// If `module_uuid_mask()` yields no value.
pub fn parse_socket(msg: &str) -> Option<(String, String)>
{
    let uuid_len = super::conf::module_uuid_mask().unwrap().len();

    // Need room for at least the uuid plus the ':' separator.
    if uuid_len >= msg.len()
    {
        return None;
    }

    // ':' is a single ASCII byte, so when it sits at `uuid_len` that index is a char
    // boundary and both slices below are valid.
    if msg.as_bytes()[uuid_len] != b':'
    {
        return None;
    }

    let (uuid, rest) = msg.split_at(uuid_len);
    Some((uuid.to_string(), rest[1..].to_string()))
}